Pytorch Shift Stack Demo¶

We demonstrate the shift-and-stack technique on a sequence of aligned frames.

Loading Data + Preprocessing¶

We first loop through the directories where the data lives after querying the data.

In [1]:
import os

# Collect the paths of every FITS file in the aligned-data directory.
directory = os.fsencode('Data/aligned')
files = [
    'Data/aligned' + "/" + os.fsdecode(entry)
    for entry in os.listdir(directory)
    if '.fits' in os.fsdecode(entry)
]

We can then load the FITS files and keep track of the observation times in a dictionary. NOTE: These files are preprocessed and aligned for RA/DEC offsets.

In [2]:
import os
from tqdm import tqdm
from astropy.io import fits
from astropy.time import TimeDelta, Time
import torch
import numpy as np
from astropy.io import fits
import matplotlib.pyplot as plt
from astropy.visualization import (ImageNormalize, ZScaleInterval, LogStretch, MinMaxInterval, AsinhStretch)
from astropy.wcs import WCS
from scipy.ndimage import shift
from astropy.coordinates import SkyCoord

total_data = {}    # maps observation Time -> float32 image tensor
pass_counter = 0   # number of frames skipped for having too many NaNs
MAGZPL_LIST = []   # photometric zero points (MAGZP) of the kept frames

for filename in tqdm(files):
    # Context manager guarantees the FITS handle is closed even if a
    # header keyword is missing or the data is malformed.
    with fits.open(filename) as hdul:
        # Primary HDU (Header Data Unit) holds both image and header.
        primary_hdu = hdul[0]
        data = primary_hdu.data
        header = primary_hdu.header

        # Skip frames dominated by NaNs (20% or more of all pixels).
        nan_percentage = (np.isnan(data).sum() / data.size) * 100
        if nan_percentage < 20:
            time = header['SHUTOPEN']
            MAGZPL_LIST.append(header['MAGZP'])
            # SHUTOPEN is ISO-like with a 'T' separator; astropy's 'iso'
            # format expects a space.
            time = Time(time.replace('T', ' '), format='iso', scale='utc')
            total_data[time] = torch.tensor(data.astype(np.float32))
        else:
            pass_counter += 1

print(pass_counter, 'skipped out of ', len(files)) 
# Use the largest zero point as the common reference for flux scaling later.
MAGZPL = np.max(MAGZPL_LIST)
print(MAGZPL)
100%|█████████████████████████████████████████| 939/939 [00:21<00:00, 44.50it/s]
671 skipped out of  939
26.416529

Get the earliest time of observation

In [3]:
# Observation times in chronological order; the first entry is the
# reference epoch for all later time differences.
sorted_keys = sorted(total_data)
earliest_frame = sorted_keys[0]

Load the earliest data frame and plot what it looks like

In [4]:
# Asinh stretch with z-scale limits brings out faint structure.
norm = ImageNormalize(stretch=AsinhStretch(a=0.0001), interval=ZScaleInterval())
data = total_data[earliest_frame]

# Plot the earliest frame with the chosen normalization.
fig, ax = plt.subplots(figsize=(10, 8))
im = ax.imshow(data, cmap='gray_r', origin='lower', norm=norm)
fig.colorbar(im, label='Pixel Value')
ax.set_title('FITS Image with Normalization')
ax.set_xlabel('RA Pixel')
ax.set_ylabel('Dec Pixel')
plt.show()
No description has been provided for this image

To verify we have aligned the frames together we plot a movie of the frames

In [5]:
from astropy.io import fits
import matplotlib.pyplot as plt
from astropy.visualization import (ImageNormalize, ZScaleInterval, LogStretch, MinMaxInterval, AsinhStretch)
import numpy as np
import matplotlib.pyplot as plt
from celluloid import Camera
from IPython.display import HTML

norm = ImageNormalize(stretch=AsinhStretch(a=0.0001), interval=ZScaleInterval())  

%matplotlib inline
# Set up the figure and axes
fig, ax = plt.subplots(figsize=(8,8))

# Initialize Celluloid Camera
camera = Camera(fig)


data = total_data[earliest_frame]
count = 0
# Animate the sine wave with a changing phase
for key in sorted_keys:
    ax.imshow(total_data[key], cmap='gray_r', origin='lower', norm=norm)  # Plot the sine wave
    camera.snap()  # Capture the frame
    if count > 100:
        break
    else:
        count += 1

# Create the animation
animation = camera.animate(interval=100)  # 100ms between frames

# Display the animation
HTML(animation.to_jshtml())  # Render as JS HTML
Out[5]:
No description has been provided for this image
No description has been provided for this image

Add fakes¶

We add a Gaussian dot that shifts across the sky as a function of time.

In [6]:
import numpy as np
import matplotlib.pyplot as plt

def generate_2d_gaussian(grid_size, amplitude=1, x0=0, y0=0, sigma_x=1, sigma_y=1):
    """Evaluate an elliptical 2-D Gaussian over a pixel grid.

    grid_size: (Nx, Ny) — note the returned array has shape (Ny, Nx),
        matching numpy's row-major image convention.
    amplitude: peak value of the Gaussian.
    x0, y0: center position in grid coordinates (origin at frame center).
    sigma_x, sigma_y: standard deviations along each axis.
    Returns a (Ny, Nx) numpy array.
    """
    n_x, n_y = grid_size
    xs = np.linspace(-n_x // 2, n_x // 2, n_x)
    ys = np.linspace(-n_y // 2, n_y // 2, n_y)
    grid_x, grid_y = np.meshgrid(xs, ys)
    term_x = ((grid_x - x0) ** 2) / (2 * sigma_x ** 2)
    term_y = ((grid_y - y0) ** 2) / (2 * sigma_y ** 2)
    return amplitude * np.exp(-(term_x + term_y))
In [7]:
# Parameters for the synthetic moving source.
ra_0, dec_0 = 0, 0  # center of the frame
dx, dy = 1, 2       # projected velocity in pixels/day (up and to the right)
MAG = 19            # apparent magnitude of the injected star
In [8]:
total_data = {}    # maps observation Time -> preprocessed float32 tensor with fake injected
pass_counter = 0   # number of frames skipped for having too many NaNs

for filename in tqdm(files):
    # Context manager guarantees the FITS handle is closed on every path.
    with fits.open(filename) as hdul:
        primary_hdu = hdul[0]
        data = primary_hdu.data
        header = primary_hdu.header

        # Keep only frames that are not dominated by NaNs (< 20% of pixels).
        nan_percentage = (np.isnan(data).sum() / data.size) * 100
        if nan_percentage >= 20:
            pass_counter += 1
            continue

        time = header['SHUTOPEN']
        time = Time(time.replace('T', ' '), format='iso', scale='utc')

        # ===== step 1: mask saturated pixels (above half the saturation level) =====
        saturate = header['SATURATE']
        data = np.where(data > saturate/2, np.nan, data)

        # ===== step 2: sky subtraction — remove median background, zero-fill NaNs =====
        median = np.nanmedian(data)
        data = data - median
        data[np.isnan(data)] = 0

        # ===== step 3: scale flux onto the common zero point MAGZPL =====
        MAGZP = header['MAGZP']
        scale = 10**(0.4*(MAGZPL - MAGZP))
        data = data * scale

        # ===== step 4: drop in the Gaussian fake, drifting linearly with time =====
        grid_size = data.shape[::-1]  # (Nx, Ny) for generate_2d_gaussian
        dt = (time - earliest_frame).to_value('day')
        x0, y0 = dt * dx, dt * dy  # position shifts by (dx, dy) pixels/day
        FWHM = header['SEEING']
        sigma = FWHM/2.3548  # FWHM -> Gaussian sigma
        # Amplitude set so the integrated flux matches magnitude MAG at MAGZPL.
        amplitude = 10**(-0.4*(MAG-MAGZPL)) * (1/(2*np.pi*sigma**2))
        gaussian = generate_2d_gaussian(grid_size, amplitude, x0, y0, sigma, sigma)
        data = data + gaussian

        total_data[time] = torch.tensor(data.astype(np.float32))

print(pass_counter, 'skipped out of ', len(files))  
100%|█████████████████████████████████████████| 939/939 [01:07<00:00, 13.82it/s]
671 skipped out of  939

Watch the injections!¶

Let us visualise the fake injected star

In [9]:
from astropy.io import fits
import matplotlib.pyplot as plt
from astropy.visualization import (ImageNormalize, ZScaleInterval, LogStretch, MinMaxInterval, AsinhStretch)
import numpy as np
import matplotlib.pyplot as plt
from celluloid import Camera
from IPython.display import HTML

norm = ImageNormalize(stretch=AsinhStretch(a=0.0001), interval=ZScaleInterval())  

%matplotlib inline
# Set up the figure and axes
fig, ax = plt.subplots(figsize=(10,10))

# Initialize Celluloid Camera
camera = Camera(fig)


data = total_data[earliest_frame]
count = 0
# Animate the sine wave with a changing phase
for key in sorted_keys:
    ax.imshow(total_data[key], cmap='gray_r', origin='lower', norm=norm)  # Plot the sine wave
    camera.snap()  # Capture the frame
    if count > 100:
        break
    else:
        count += 1

# Create the animation
animation = camera.animate(interval=100)  # 100ms between frames

# Display the animation
HTML(animation.to_jshtml())  # Render as JS HTML
Out[9]:
No description has been provided for this image
No description has been provided for this image

Shift and Add¶

We write a batched version of the shift-and-add model.

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time

class Shift_Stack(nn.Module):
    """Batched shift-and-stack: shift each frame by a time-scaled velocity
    and median-combine the shifted stack on the GPU.

    frames: tensor of images stacked along dim 0, one image per epoch
            (assumed shape (B, H, W) — the constructor adds a channel axis).
    times:  1-D float tensor of observation times; units must match the
            velocities passed to forward() (e.g. Julian days for pixel/day).
    """

    def __init__(self, frames, times):
        super(Shift_Stack, self).__init__()
        # Subtract the reference (first) frame, then add a channel axis:
        # (B, H, W) -> (B, 1, H, W), as required by F.grid_sample.
        self.frames = self.frame_subtraction(frames).cuda()[:, None, :, :]
        # Time elapsed since the first frame; entry 0 stays zero.
        self.delta_times = torch.zeros_like(times).cuda()
        self.delta_times[1:] = times[1:] - times[0]
        self.B = self.frames.shape[0]  # number of frames (batch size)
        # Precompute an identity sampling grid once; shift_images() only
        # offsets a clone of it, which avoids rebuilding the grid per call.
        self.grid = F.affine_grid(
            torch.eye(2, 3).unsqueeze(0).repeat(self.B, 1, 1),  # identity affine per frame
            size=self.frames.size(),
            align_corners=False
        ).cuda()

    def frame_subtraction(self, frame):
        """Subtract the first frame from every frame (difference imaging
        against the reference epoch); the first frame becomes all zeros."""
        frame = frame - frame[0, :, :]
        return frame

    def decimate_32_to_16(self, data_float32):
        """Decimate data from float32 to float16 while preserving the range."""
        # Scaling factor to bring values within float16's representable range.
        scale_factor = data_float32.abs().max() / 6.55e4  # approx. max for float16
        scale_factor = max(scale_factor, 1.0)  # ensure we only scale down

        # Scale, convert to float16, and scale back (loses precision, not range).
        data_scaled = data_float32 / scale_factor
        data_float16 = data_scaled.to(dtype=torch.float16)
        data_preserved = data_float16 * scale_factor
        return data_preserved

    def average_nonzeros(self, data, axis=0):
        """Average only the nonzero values along `axis`."""
        non_zero_sum = torch.sum(data, dim=axis)
        non_zero_count = torch.count_nonzero(data, dim=axis)

        # Avoid division by zero where every element along the axis is zero.
        non_zero_count = torch.where(non_zero_count == 0, torch.tensor(1.0, device=data.device), non_zero_count)

        non_zero_avg = non_zero_sum / non_zero_count
        return non_zero_avg

    def shift_images(self, images, shifts):
        """Shift a batch of images by per-image (x, y) pixel offsets.

        images (torch.Tensor): batch of shape (B, C, H, W).
        shifts (torch.Tensor): shape (B, 2); column 0 is the x shift and
                               column 1 the y shift, in pixels.
        Returns the bilinearly resampled images, zero-padded at the edges.
        """
        B, C, H, W = images.shape
        shifts_x, shifts_y = shifts[:, 0], shifts[:, 1]
        # Offset a copy of the precomputed identity grid. Grid coordinates
        # span [-1, 1], so a shift of s pixels is 2*s/size in grid units.
        grid = self.grid.clone()
        grid[..., 0] += 2 * shifts_x.view(-1, 1, 1) / W  # normalized shift in x
        grid[..., 1] += 2 * shifts_y.view(-1, 1, 1) / H  # normalized shift in y
        # Resample; inputs are already on the GPU, so the output is too.
        shifted_images = F.grid_sample(
            images, grid, mode='bilinear', padding_mode='zeros', align_corners=False
        )
        return shifted_images

    def forward(self, dxdy):
        """Preprocessing handling before calling shift_images.

        dxdy: (B, 2) per-frame velocities (pixels per time unit); each is
              multiplied by that frame's elapsed time to get a pixel shift.
        Returns torch.median's (values, indices) namedtuple over the stack;
        callers read `.values` for the stacked image.
        """
        dx, dy = dxdy[:, 0], dxdy[:, 1]
        shifts_x = torch.multiply(dx, self.delta_times)  # pixel shift in x
        shifts_y = torch.multiply(dy, self.delta_times)  # pixel shift in y
        shifts = torch.stack([shifts_x, shifts_y], axis=1)
        frames = self.frames
        batched_shift = self.shift_images(frames, shifts)
        batched_shift = batched_shift[:, 0, :, :]  # meant for future multi-filters... [curr. index zero for 1 filter]
        stacked = torch.median(batched_shift, axis=0)
        return stacked

Load up the data into torch tensors.

In [11]:
# Collect the first 150 frames and their Julian-date timestamps into tensors.
frames = []
frame_times = []
adjusts = []
for key in sorted_keys[:150]:
    frames.append(total_data[key])
    frame_times.append(key.to_value('jd'))
frames = torch.stack(frames)
frame_times = torch.tensor(frame_times)
print(frames.shape, frame_times.shape)
torch.Size([150, 3080, 3072]) torch.Size([150])

Instantiate a single shift-and-stack model over the loaded sequence of frames.

In [12]:
# Build the stacking model once; frames, time deltas, and the sampling grid
# are moved to the GPU inside the constructor.
stacking_job = Shift_Stack(frames, frame_times)

Run a job on shift and stacking for 1 sequence of orbits

In [13]:
# One (dx, dy) velocity per frame; the velocity is constant here, so every
# row gets the same pair.
shifts = torch.ones((150, 2)).cuda()
shifts[:, 0], shifts[:, 1] = dx, dy
print(shifts.shape)
torch.Size([150, 2])

The shifts data object has shape $[\mathrm{time}, \mathrm{spatial}]$: there are 150 frames at different times, and each frame has a corresponding 2D adjustment vector. Since the velocity is constant in time, we do not need different $dx$, $dy$ values as a function of time.

In [14]:
# Time one full shift-and-stack evaluation on the GPU.
start = time()
result = stacking_job(shifts)
print(time() - start, "runtime [s]")
0.11995458602905273 runtime [s]
In [17]:
# Same asinh/z-scale normalization used for the individual frames.
norm = ImageNormalize(stretch=AsinhStretch(a=0.0001), interval=ZScaleInterval())

fig, ax = plt.subplots(figsize=(10, 8))
# torch.median returns a (values, indices) namedtuple; plot the values.
stacked_image = result.values.cpu().numpy()
im = ax.imshow(stacked_image, cmap='gray_r', origin='lower', norm=norm)
fig.colorbar(im, label='Pixel Value')
ax.set_title('FITS Image with Normalization')
ax.set_xlabel('RA Pixel')
ax.set_ylabel('Dec Pixel')
plt.show()
No description has been provided for this image